{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GA3C-CADRL\n",
    "#### GPU/CPU Asynchronous Advantage Actor-Critic for Collision Avoidance with Deep Reinforcement Learning\n",
    "Michael Everett, Yu Fan Chen, and Jonathan P. How<br>\n",
    "Manuscript submitted to 2018 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Objective:** This goal of this notebook is to explain how to use our code, enabling other researchers to test and compare against the results presented in the paper. After reading this notebook, it should also be clear how our code could be implemented on your own system (i.e. what format you should provide as input, and what information you'll get as output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create an instance of an Agent\n",
    "The most important class is Agent, which has attributes such as radius and position, and methods such as find_next_action. The environment is made up of several Agents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import agent\n",
    "import network\n",
    "import util\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load trained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Scale of 0 disables regularizer.\n",
      "INFO:tensorflow:Restoring parameters from ../checkpoints/network_01900000\n"
     ]
    }
   ],
   "source": [
    "possible_actions = network.Actions()\n",
    "num_actions = possible_actions.num_actions\n",
    "nn = network.NetworkVP_rnn(network.Config.DEVICE, 'network', num_actions)\n",
    "nn.simple_load('../checkpoints/network_01900000')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set current state of host agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_x = 2\n",
    "start_y = 5\n",
    "goal_x = 3\n",
    "goal_y = 2\n",
    "radius = 0.5\n",
    "pref_speed = 1.2\n",
    "heading_angle = 0\n",
    "index = 0\n",
    "v_x = 0\n",
    "v_y = 0\n",
    "\n",
    "host_agent = agent.Agent(start_x, start_y, goal_x, goal_y, radius, pref_speed, heading_angle, index)\n",
    "host_agent.vel_global_frame = np.array([v_x, v_y])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set current state of other agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample observation data in a format easily generated from sensors\n",
    "other_agents_x = [-1,-2,-3]\n",
    "other_agents_y = [2,3,4]\n",
    "other_agents_r = [0.5, 0.4, 0.3]\n",
    "other_agents_vx = [1.0, 0.6, 0.2]\n",
    "other_agents_vy = [0.0, 0.6, 0.8]\n",
    "num_other_agents = len(other_agents_x)\n",
    "\n",
    "# Create Agent objects for each observed dynamic obstacle\n",
    "other_agents = []\n",
    "for i in range(num_other_agents):\n",
    "    x = other_agents_x[i]; y = other_agents_y[i]\n",
    "    v_x = other_agents_vx[i]; v_y = other_agents_vy[i]\n",
    "    radius = other_agents_r[i]\n",
    "    \n",
    "    # dummy info - unobservable states not used by NN, just needed to create Agent object\n",
    "    heading_angle = np.arctan2(v_y, v_x) \n",
    "    pref_speed = np.linalg.norm(np.array([v_x, v_y]))\n",
    "    goal_x = x + 5.0; goal_y = y + 5.0\n",
    "    \n",
    "    other_agents.append(agent.Agent(x, y, goal_x, goal_y, radius, pref_speed, heading_angle, i+1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert agent states into observation vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs = host_agent.observe(other_agents)[1:]\n",
    "obs = np.expand_dims(obs, axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query the policy based on observation vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "action: [ 1.2        -0.52359878]\n"
     ]
    }
   ],
   "source": [
    "predictions = nn.predict_p(obs, None)[0]\n",
    "raw_action = possible_actions.actions[np.argmax(predictions)]\n",
    "action = np.array([host_agent.pref_speed*raw_action[0], util.wrap(raw_action[1] + host_agent.heading_global_frame)])\n",
    "print \"action:\", action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
